// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "common.h"

/**  SYMV  performs the matrix-vector operation
 *
 *     y := alpha*A*x + beta*y,
 *
 *  where alpha and beta are scalars, x and y are n element vectors and
 *  A is an n by n matrix handled by the self-adjoint matrix-vector kernel.
 *  uplo selects between the Upper and the Lower instantiation of that kernel.
 *
 *  Invalid arguments are reported through xerbla_ with the following
 *  argument positions: 1 (uplo), 2 (n < 0), 5 (lda < max(1,n)),
 *  7 (incx == 0), 10 (incy == 0); the value returned by xerbla_ is
 *  forwarded to the caller.
 *
 *  Returns 0 when n == 0 or when no kernel is available for uplo, and 1
 *  once the update has been applied.
 */
int
EIGEN_BLAS_FUNC(symv)(const char* uplo,
					  const int* n,
					  const RealScalar* palpha,
					  const RealScalar* pa,
					  const int* lda,
					  const RealScalar* px,
					  const int* incx,
					  const RealScalar* pbeta,
					  RealScalar* py,
					  const int* incy)
{
	typedef void (*functype)(int, const Scalar*, int, const Scalar*, Scalar*, Scalar);
	static const functype func[2] = {
		// array index: UP
		(internal::selfadjoint_matrix_vector_product<Scalar, int, ColMajor, Upper, false, false>::run),
		// array index: LO
		(internal::selfadjoint_matrix_vector_product<Scalar, int, ColMajor, Lower, false, false>::run),
	};

	const Scalar* a = reinterpret_cast<const Scalar*>(pa);
	const Scalar* x = reinterpret_cast<const Scalar*>(px);
	Scalar* y = reinterpret_cast<Scalar*>(py);
	Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);
	Scalar beta = *reinterpret_cast<const Scalar*>(pbeta);

	// check arguments
	int info = 0;
	if (UPLO(*uplo) == INVALID)
		info = 1;
	else if (*n < 0)
		info = 2;
	else if (*lda < std::max(1, *n))
		info = 5;
	else if (*incx == 0)
		info = 7;
	else if (*incy == 0)
		info = 10;
	if (info)
		return xerbla_(SCALAR_SUFFIX_UP "SYMV ", &info, 6);

	// quick return: nothing to compute for an empty problem
	if (*n == 0)
		return 0;

	// Resolve the kernel BEFORE any temporary is created: bailing out after
	// get_compact_vector() would leak the buffers it may have allocated.
	const int code = UPLO(*uplo);
	if (code >= 2 || func[code] == 0)
		return 0;

	// The kernel works on compact vectors; a buffer different from the
	// caller's pointer is a temporary that has to be released below.
	const Scalar* actual_x = get_compact_vector(x, *n, *incx);
	Scalar* actual_y = get_compact_vector(y, *n, *incy);

	// y := beta*y; beta == 0 overwrites y instead of multiplying it
	if (beta != Scalar(1)) {
		if (beta == Scalar(0))
			make_vector(actual_y, *n).setZero();
		else
			make_vector(actual_y, *n) *= beta;
	}

	// y += alpha*A*x
	func[code](*n, a, *lda, actual_x, actual_y, alpha);

	if (actual_x != x)
		delete[] actual_x;
	// copy_back presumably writes the result into the caller's strided y and
	// hands back the temporary to free -- see its definition to confirm.
	if (actual_y != y)
		delete[] copy_back(actual_y, y, *n, *incy);

	return 1;
}

/**  SYR  performs the symmetric rank 1 operation
 *
 *     C := alpha*x*x' + C,
 *
 *  where alpha is a scalar, x is an n element vector and C is an n by n
 *  matrix with leading dimension ldc. uplo selects between the Upper and
 *  the Lower instantiation of the rank-1 update kernel.
 *
 *  Invalid arguments are reported through xerbla_ with the following
 *  argument positions: 1 (uplo), 2 (n < 0), 5 (incx == 0),
 *  7 (ldc < max(1,n)); the value returned by xerbla_ is forwarded.
 *
 *  Returns 1 on success or when there is nothing to do (n == 0 or
 *  alpha == 0), and 0 when no kernel is available for uplo.
 */
int
EIGEN_BLAS_FUNC(syr)(const char* uplo,
					 const int* n,
					 const RealScalar* palpha,
					 const RealScalar* px,
					 const int* incx,
					 RealScalar* pc,
					 const int* ldc)
{

	typedef void (*functype)(int, Scalar*, int, const Scalar*, const Scalar*, const Scalar&);
	static const functype func[2] = {
		// array index: UP
		(selfadjoint_rank1_update<Scalar, int, ColMajor, Upper, false, Conj>::run),
		// array index: LO
		(selfadjoint_rank1_update<Scalar, int, ColMajor, Lower, false, Conj>::run),
	};

	const Scalar* x = reinterpret_cast<const Scalar*>(px);
	Scalar* c = reinterpret_cast<Scalar*>(pc);
	Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);

	// check arguments
	int info = 0;
	if (UPLO(*uplo) == INVALID)
		info = 1;
	else if (*n < 0)
		info = 2;
	else if (*incx == 0)
		info = 5;
	else if (*ldc < std::max(1, *n))
		info = 7;
	if (info)
		return xerbla_(SCALAR_SUFFIX_UP "SYR  ", &info, 6);

	// quick return: the update is a no-op
	if (*n == 0 || alpha == Scalar(0))
		return 1;

	// Resolve the kernel BEFORE the temporary is created: bailing out after
	// get_compact_vector() would leak the buffer it may have allocated.
	const int code = UPLO(*uplo);
	if (code >= 2 || func[code] == 0)
		return 0;

	// if the increment is not 1, let's copy it to a temporary vector to enable vectorization
	const Scalar* x_cpy = get_compact_vector(x, *n, *incx);

	// x is passed for both vector operands of the rank-1 kernel
	func[code](*n, c, *ldc, x_cpy, x_cpy, alpha);

	if (x_cpy != x)
		delete[] x_cpy;

	return 1;
}

/**  SYR2  performs the symmetric rank 2 operation
 *
 *     C := alpha*x*y' + alpha*y*x' + C,
 *
 *  where alpha is a scalar, x and y are n element vectors and C is an
 *  n by n matrix with leading dimension ldc. uplo selects between the
 *  Upper and the Lower instantiation of the rank-2 update kernel.
 *
 *  Invalid arguments are reported through xerbla_ with the following
 *  argument positions: 1 (uplo), 2 (n < 0), 5 (incx == 0), 7 (incy == 0),
 *  9 (ldc < max(1,n)); the value returned by xerbla_ is forwarded.
 *
 *  Returns 1 on success or when there is nothing to do (n == 0 or
 *  alpha == 0), and 0 when no kernel is available for uplo.
 */
int
EIGEN_BLAS_FUNC(syr2)(const char* uplo,
					  const int* n,
					  const RealScalar* palpha,
					  const RealScalar* px,
					  const int* incx,
					  const RealScalar* py,
					  const int* incy,
					  RealScalar* pc,
					  const int* ldc)
{
	typedef void (*functype)(int, Scalar*, int, const Scalar*, const Scalar*, Scalar);
	static const functype func[2] = {
		// array index: UP
		(internal::rank2_update_selector<Scalar, int, Upper>::run),
		// array index: LO
		(internal::rank2_update_selector<Scalar, int, Lower>::run),
	};

	const Scalar* x = reinterpret_cast<const Scalar*>(px);
	const Scalar* y = reinterpret_cast<const Scalar*>(py);
	Scalar* c = reinterpret_cast<Scalar*>(pc);
	Scalar alpha = *reinterpret_cast<const Scalar*>(palpha);

	// check arguments
	int info = 0;
	if (UPLO(*uplo) == INVALID)
		info = 1;
	else if (*n < 0)
		info = 2;
	else if (*incx == 0)
		info = 5;
	else if (*incy == 0)
		info = 7;
	else if (*ldc < std::max(1, *n))
		info = 9;
	if (info)
		return xerbla_(SCALAR_SUFFIX_UP "SYR2 ", &info, 6);

	// quick return: the update is a no-op (same condition as in syr)
	if (*n == 0 || alpha == Scalar(0))
		return 1;

	// Resolve the kernel BEFORE the temporaries are created: bailing out
	// after get_compact_vector() would leak the buffers it may have allocated.
	const int code = UPLO(*uplo);
	if (code >= 2 || func[code] == 0)
		return 0;

	// The kernel works on compact vectors; a buffer different from the
	// caller's pointer is a temporary that has to be released below.
	const Scalar* x_cpy = get_compact_vector(x, *n, *incx);
	const Scalar* y_cpy = get_compact_vector(y, *n, *incy);

	func[code](*n, c, *ldc, x_cpy, y_cpy, alpha);

	if (x_cpy != x)
		delete[] x_cpy;
	if (y_cpy != y)
		delete[] y_cpy;

	return 1;
}

/**  DSBMV  performs the matrix-vector  operation
 *
 *     y := alpha*A*x + beta*y,
 *
 *  where alpha and beta are scalars, x and y are n element vectors and
 *  A is an n by n symmetric band matrix, with k super-diagonals.
 */
// int EIGEN_BLAS_FUNC(sbmv)( char *uplo, int *n, int *k, RealScalar *alpha, RealScalar *a, int *lda,
//                            RealScalar *x, int *incx, RealScalar *beta, RealScalar *y, int *incy)
// {
//   return 1;
// }

/**  DSPMV  performs the matrix-vector operation
 *
 *     y := alpha*A*x + beta*y,
 *
 *  where alpha and beta are scalars, x and y are n element vectors and
 *  A is an n by n symmetric matrix, supplied in packed form.
 *
 */
// int EIGEN_BLAS_FUNC(spmv)(char *uplo, int *n, RealScalar *alpha, RealScalar *ap, RealScalar *x, int *incx, RealScalar
// *beta, RealScalar *y, int *incy)
// {
//   return 1;
// }

/**  DSPR    performs the symmetric rank 1 operation
 *
 *     A := alpha*x*x' + A,
 *
 *  where alpha is a real scalar, x is an n element vector and A is an
 *  n by n symmetric matrix, supplied in packed form.
 *
 *  uplo selects between the Upper and the Lower instantiation of the
 *  packed rank-1 update kernel.
 *
 *  Invalid arguments are reported through xerbla_ with the following
 *  argument positions: 1 (uplo), 2 (n < 0), 5 (incx == 0); the value
 *  returned by xerbla_ is forwarded.
 *
 *  Returns 1 on success or when there is nothing to do (n == 0 or
 *  alpha == 0), and 0 when no kernel is available for uplo.
 */
int
EIGEN_BLAS_FUNC(spr)(char* uplo, int* n, Scalar* palpha, Scalar* px, int* incx, Scalar* pap)
{
	typedef void (*functype)(int, Scalar*, const Scalar*, Scalar);
	static const functype func[2] = {
		// array index: UP
		(internal::selfadjoint_packed_rank1_update<Scalar, int, ColMajor, Upper, false, false>::run),
		// array index: LO
		(internal::selfadjoint_packed_rank1_update<Scalar, int, ColMajor, Lower, false, false>::run),
	};

	Scalar* x = reinterpret_cast<Scalar*>(px);
	Scalar* ap = reinterpret_cast<Scalar*>(pap);
	Scalar alpha = *reinterpret_cast<Scalar*>(palpha);

	// check arguments
	int info = 0;
	if (UPLO(*uplo) == INVALID)
		info = 1;
	else if (*n < 0)
		info = 2;
	else if (*incx == 0)
		info = 5;
	if (info)
		return xerbla_(SCALAR_SUFFIX_UP "SPR  ", &info, 6);

	// quick return: the update is a no-op (same condition as in syr)
	if (*n == 0 || alpha == Scalar(0))
		return 1;

	// Resolve the kernel BEFORE the temporary is created: bailing out after
	// get_compact_vector() would leak the buffer it may have allocated.
	const int code = UPLO(*uplo);
	if (code >= 2 || func[code] == 0)
		return 0;

	// The kernel works on a compact vector; a buffer different from the
	// caller's pointer is a temporary that has to be released below.
	Scalar* x_cpy = get_compact_vector(x, *n, *incx);

	func[code](*n, ap, x_cpy, alpha);

	if (x_cpy != x)
		delete[] x_cpy;

	return 1;
}

/**  DSPR2  performs the symmetric rank 2 operation
 *
 *     A := alpha*x*y' + alpha*y*x' + A,
 *
 *  where alpha is a scalar, x and y are n element vectors and A is an
 *  n by n symmetric matrix, supplied in packed form.
 *
 *  uplo selects between the Upper and the Lower instantiation of the
 *  packed rank-2 update kernel.
 *
 *  Invalid arguments are reported through xerbla_ with the following
 *  argument positions: 1 (uplo), 2 (n < 0), 5 (incx == 0), 7 (incy == 0);
 *  the value returned by xerbla_ is forwarded.
 *
 *  Returns 1 on success or when there is nothing to do (n == 0 or
 *  alpha == 0), and 0 when no kernel is available for uplo.
 */
int
EIGEN_BLAS_FUNC(
	spr2)(char* uplo, int* n, RealScalar* palpha, RealScalar* px, int* incx, RealScalar* py, int* incy, RealScalar* pap)
{
	typedef void (*functype)(int, Scalar*, const Scalar*, const Scalar*, Scalar);
	static const functype func[2] = {
		// array index: UP
		(internal::packed_rank2_update_selector<Scalar, int, Upper>::run),
		// array index: LO
		(internal::packed_rank2_update_selector<Scalar, int, Lower>::run),
	};

	Scalar* x = reinterpret_cast<Scalar*>(px);
	Scalar* y = reinterpret_cast<Scalar*>(py);
	Scalar* ap = reinterpret_cast<Scalar*>(pap);
	Scalar alpha = *reinterpret_cast<Scalar*>(palpha);

	// check arguments
	int info = 0;
	if (UPLO(*uplo) == INVALID)
		info = 1;
	else if (*n < 0)
		info = 2;
	else if (*incx == 0)
		info = 5;
	else if (*incy == 0)
		info = 7;
	if (info)
		return xerbla_(SCALAR_SUFFIX_UP "SPR2 ", &info, 6);

	// quick return: the update is a no-op (same condition as in syr)
	if (*n == 0 || alpha == Scalar(0))
		return 1;

	// Resolve the kernel BEFORE the temporaries are created: bailing out
	// after get_compact_vector() would leak the buffers it may have allocated.
	const int code = UPLO(*uplo);
	if (code >= 2 || func[code] == 0)
		return 0;

	// The kernel works on compact vectors; a buffer different from the
	// caller's pointer is a temporary that has to be released below.
	Scalar* x_cpy = get_compact_vector(x, *n, *incx);
	Scalar* y_cpy = get_compact_vector(y, *n, *incy);

	func[code](*n, ap, x_cpy, y_cpy, alpha);

	if (x_cpy != x)
		delete[] x_cpy;
	if (y_cpy != y)
		delete[] y_cpy;

	return 1;
}

/**  DGER   performs the rank 1 operation
 *
 *     A := alpha*x*y' + A,
 *
 *  where alpha is a scalar, x is an m element vector, y is an n element
 *  vector and A is an m by n matrix.
 *
 *  Invalid arguments are reported through xerbla_ with the following
 *  argument positions: 1 (m < 0), 2 (n < 0), 5 (incx == 0), 7 (incy == 0),
 *  9 (lda < max(1,m)); the value returned by xerbla_ is forwarded.
 *
 *  Returns 1 on success or when there is nothing to do (m == 0, n == 0 or
 *  alpha == 0).
 */
int
EIGEN_BLAS_FUNC(ger)(int* m, int* n, Scalar* palpha, Scalar* px, int* incx, Scalar* py, int* incy, Scalar* pa, int* lda)
{
	Scalar* x = reinterpret_cast<Scalar*>(px);
	Scalar* y = reinterpret_cast<Scalar*>(py);
	Scalar* a = reinterpret_cast<Scalar*>(pa);
	Scalar alpha = *reinterpret_cast<Scalar*>(palpha);

	// check arguments
	int info = 0;
	if (*m < 0)
		info = 1;
	else if (*n < 0)
		info = 2;
	else if (*incx == 0)
		info = 5;
	else if (*incy == 0)
		info = 7;
	else if (*lda < std::max(1, *m))
		info = 9;
	if (info)
		return xerbla_(SCALAR_SUFFIX_UP "GER  ", &info, 6);

	// quick return: an empty matrix or a zero alpha makes the update a no-op,
	// so skip compacting the vectors and calling the kernel
	if (*m == 0 || *n == 0 || alpha == Scalar(0))
		return 1;

	// The kernel works on compact vectors; a buffer different from the
	// caller's pointer is a temporary that has to be released below.
	Scalar* x_cpy = get_compact_vector(x, *m, *incx);
	Scalar* y_cpy = get_compact_vector(y, *n, *incy);

	internal::general_rank1_update<Scalar, int, ColMajor, false, false>::run(*m, *n, a, *lda, x_cpy, y_cpy, alpha);

	if (x_cpy != x)
		delete[] x_cpy;
	if (y_cpy != y)
		delete[] y_cpy;

	return 1;
}
